import gym
import torch

from continual_rl.utils.env_wrappers import OldInterfaceWrapper, FrameStack
from continual_rl.utils.utils import Utils

from continual_rl.experiments.tasks.preprocessor_base import PreprocessorBase

from continual_rl.experiments.tasks.task_base import TaskBase


class MuJoCopToPyTorch(gym.ObservationWrapper):
    """Observation wrapper that returns every observation as a PyTorch tensor."""

    def observation(self, observation):
        """Convert a raw environment observation into a ``torch.Tensor``.

        :param observation: Observation produced by the wrapped environment.
        :return: The result of ``torch.tensor(observation)``; no explicit dtype or device
            is requested, so both are left to ``torch.tensor``'s defaults/inference.
        """
        return torch.tensor(observation)


class MuJoCoPreprocessor(PreprocessorBase):
    """Preprocessor for MuJoCo tasks: wraps the environment spec and batches observations."""

    def __init__(self, env_spec):
        """Wrap ``env_spec`` and initialise the base class with the wrapped observation space.

        :param env_spec: Environment specification accepted by ``Utils.make_env``.
        """
        self.env_spec = self._wrap_env(env_spec)

        # Build a throwaway environment only to read the observation space of the wrapped
        # environment. (This would be unnecessary if the wrappers left the observation
        # space unchanged.)
        dummy_env, _ = Utils.make_env(self.env_spec)
        try:
            observation_space = dummy_env.observation_space
        finally:
            # The dummy environment is never used again, so release it instead of leaking it.
            dummy_env.close()

        super().__init__(observation_space)

    def _wrap_env(self, env_spec):
        """Return a zero-argument factory that builds the wrapped environment.

        Each call of the factory creates a fresh environment from ``env_spec`` and wraps it,
        innermost first, in ``MuJoCopToPyTorch``, ``OldInterfaceWrapper`` and
        ``FrameStack(..., 1)``.

        :param env_spec: Environment specification accepted by ``Utils.make_env``.
        :return: Callable taking no arguments and returning the wrapped environment.
        """
        def make_wrapped_env():
            tensor_env = MuJoCopToPyTorch(Utils.make_env(env_spec)[0])
            # NOTE(review): the trailing 1 is presumably the number of stacked frames -- confirm
            # against FrameStack's signature.
            return FrameStack(OldInterfaceWrapper(tensor_env), 1)

        return make_wrapped_env

    def preprocess(self, batched_obs):
        """Stack a batch of observations into a single tensor.

        :param batched_obs: Iterable of observations; each element must expose ``to_tensor()``
            (presumably the frame container produced by ``FrameStack`` -- confirm).
        :return: ``torch.stack`` of the per-observation tensors, i.e. with a new leading
            batch dimension.
        """
        return torch.stack([obs.to_tensor() for obs in batched_obs])

    def render_episode(self, episode_observations):
        """Rendering is not implemented yet; always returns ``None``."""
        # TODO: obtain images through the environment's render call
        return None


class MuJoCoTask(TaskBase):
    """Task definition for a MuJoCo environment, preprocessed by ``MuJoCoPreprocessor``."""

    def __init__(self, task_id, action_space_id, env_spec, num_timesteps, eval_mode):
        """Build the preprocessor, read the wrapped spaces and initialise the base task.

        :param task_id: Forwarded unchanged to ``TaskBase``.
        :param action_space_id: Forwarded unchanged to ``TaskBase``.
        :param env_spec: Raw environment specification; it is replaced by the preprocessor's
            wrapped spec before being handed to ``TaskBase``.
        :param num_timesteps: Forwarded unchanged to ``TaskBase``.
        :param eval_mode: Forwarded unchanged to ``TaskBase``.
        """
        preprocessor = MuJoCoPreprocessor(env_spec)

        # The preprocessor may change the environment's observation and action spaces
        # (e.g. through wrappers), so the spec and both spaces are re-read from the
        # wrapped environment rather than from the raw one.
        env_spec = preprocessor.env_spec
        dummy_env, _ = Utils.make_env(env_spec)
        try:
            observation_space = dummy_env.observation_space
            action_space = dummy_env.action_space
        finally:
            # The dummy environment only serves to expose the spaces; release it afterwards.
            dummy_env.close()

        super().__init__(task_id, action_space_id, preprocessor, env_spec, observation_space, action_space,
                         num_timesteps, eval_mode)
